#!/usr/bin/env python3


import random
import numpy as np
from PIL.Image import LANCZOS

import torch
from torchvision import transforms

import learn2learn as l2l
from learn2learn.data.transforms import NWays, KShots, LoadData, RemapLabels, ConsecutiveLabels, FilterLabels
from learn2learn.vision.transforms import RandomClassRotation


def _omniglot_split_tasks(dataset, split_classes, ways, shots, num_tasks):
    """Build an ``l2l.data.TaskDataset`` of few-shot tasks drawn from ``split_classes``.

    Args:
        dataset: ``l2l.data.MetaDataset`` wrapping the full Omniglot dataset.
        split_classes: Class labels this split may sample from (passed to
            ``FilterLabels``).
        ways: Number of classes per task (``NWays``).
        shots: Half the number of samples drawn per class. ``KShots`` draws
            ``2 * shots`` so ``prepare_batch`` can split every task into equal
            adaptation and evaluation halves.
        num_tasks: Forwarded to ``l2l.data.TaskDataset(num_tasks=...)``.

    Returns:
        The ``l2l.data.TaskDataset`` for this split.
    """
    task_transforms = [
        FilterLabels(dataset, split_classes),
        NWays(dataset, ways),
        KShots(dataset, 2 * shots),
        LoadData(dataset),
        RemapLabels(dataset),
        ConsecutiveLabels(dataset),
        # Candidate rotation angles (degrees) handed to RandomClassRotation.
        RandomClassRotation(dataset, [0.0, 90.0, 180.0, 270.0]),
    ]
    return l2l.data.TaskDataset(dataset,
                                task_transforms=task_transforms,
                                num_tasks=num_tasks)


def get_omniglot(ways, shots, root='~/data'):
    """Return ``(train_tasks, valid_tasks, test_tasks)`` for Omniglot.

    The class ids ``0..1622`` are shuffled with the global ``random`` module.
    They are then split into 1100 training, 100 validation and 423 test
    classes, so seed ``random`` beforehand for a reproducible split. All three
    splits use the same task pipeline (see ``_omniglot_split_tasks``) and
    differ only in their classes and number of tasks.

    Args:
        ways: Number of classes per task.
        shots: Samples per class in each of the adaptation/evaluation halves.
        root: Directory where Omniglot is stored or downloaded.

    Returns:
        Tuple of three ``l2l.data.TaskDataset`` objects (20000 / 1024 / 1024 tasks).
    """
    num_classes = 1623         # class ids assumed to be 0..1622 in FullOmniglot
    train_end, valid_end = 1100, 1200

    omniglot = l2l.vision.datasets.FullOmniglot(
        root=root,
        transform=transforms.Compose([
            transforms.Resize(28, interpolation=LANCZOS),
            transforms.ToTensor(),
            # Invert pixel intensities.
            lambda x: 1.0 - x,
        ]),
        download=True,
    )
    dataset = l2l.data.MetaDataset(omniglot)

    classes = list(range(num_classes))
    random.shuffle(classes)

    train_tasks = _omniglot_split_tasks(dataset, classes[:train_end],
                                        ways, shots, num_tasks=20000)
    valid_tasks = _omniglot_split_tasks(dataset, classes[train_end:valid_end],
                                        ways, shots, num_tasks=1024)
    test_tasks = _omniglot_split_tasks(dataset, classes[valid_end:],
                                       ways, shots, num_tasks=1024)

    return train_tasks, valid_tasks, test_tasks


def _mini_imagenet_split_tasks(root, mode, ways, shots, num_tasks):
    """Build an ``l2l.data.TaskDataset`` of few-shot tasks for one mini-ImageNet split.

    Args:
        root: Directory where mini-ImageNet is stored or downloaded.
        mode: Split name forwarded to ``l2l.vision.datasets.MiniImagenet``
            (``'train'``, ``'validation'`` or ``'test'``).
        ways: Number of classes per task (``NWays``).
        shots: Half the number of samples drawn per class. ``KShots`` draws
            ``2 * shots`` so ``prepare_batch`` can split every task into equal
            adaptation and evaluation halves.
        num_tasks: Forwarded to ``l2l.data.TaskDataset(num_tasks=...)``.

    Returns:
        The ``l2l.data.TaskDataset`` for this split.
    """
    dataset = l2l.data.MetaDataset(
        l2l.vision.datasets.MiniImagenet(root=root, mode=mode, download=True))
    # One shared pipeline so every split applies RemapLabels before
    # ConsecutiveLabels, matching the Omniglot pipelines.
    task_transforms = [
        NWays(dataset, ways),
        KShots(dataset, 2 * shots),
        LoadData(dataset),
        RemapLabels(dataset),
        ConsecutiveLabels(dataset),
    ]
    return l2l.data.TaskDataset(dataset,
                                task_transforms=task_transforms,
                                num_tasks=num_tasks)


def get_mini_imagenet(ways, shots, root='~/data'):
    """Return ``(train_tasks, valid_tasks, test_tasks)`` for mini-ImageNet.

    Each split is built from its own ``MiniImagenet`` dataset with the same
    task pipeline (see ``_mini_imagenet_split_tasks``).

    Args:
        ways: Number of classes per task.
        shots: Samples per class in each of the adaptation/evaluation halves.
        root: Directory where mini-ImageNet is stored or downloaded.

    Returns:
        Tuple of three ``l2l.data.TaskDataset`` objects (20000 / 600 / 600 tasks).
    """
    train_tasks = _mini_imagenet_split_tasks(root, 'train', ways, shots, num_tasks=20000)
    valid_tasks = _mini_imagenet_split_tasks(root, 'validation', ways, shots, num_tasks=600)
    test_tasks = _mini_imagenet_split_tasks(root, 'test', ways, shots, num_tasks=600)
    return train_tasks, valid_tasks, test_tasks


def prepare_batch(batch, shots, ways, device, features=None):
    """Move a task batch to ``device`` and split it into adaptation/evaluation sets.

    Args:
        batch: ``(data, labels)`` pair of tensors indexed along dim 0.
        shots: Together with ``ways``, the first ``shots * ways`` even
            positions are selected for adaptation.
        ways: See ``shots``.
        device: Target passed to ``Tensor.to`` for both tensors.
        features: Optional callable applied to ``data`` after the move and
            before the split.

    Returns:
        ``(adaptation_data, adaptation_labels, evaluation_data, evaluation_labels)``.
        Positions ``0, 2, ..., 2 * (shots * ways - 1)`` form the adaptation
        set; every other position forms the evaluation set. This assumes
        (TODO confirm for other callers) that each class's samples are
        contiguous, as the ConsecutiveLabels pipelines in this file arrange,
        so alternating positions give every class an equal share.
    """
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    if features is not None:
        inputs = features(inputs)

    # Boolean masks over dim 0: adaptation = first `shots * ways` even
    # positions, evaluation = its complement.
    adapt_positions = 2 * np.arange(shots * ways)
    adapt_mask_np = np.zeros(inputs.size(0), dtype=bool)
    adapt_mask_np[adapt_positions] = True
    adapt_mask = torch.from_numpy(adapt_mask_np)
    eval_mask = torch.from_numpy(np.logical_not(adapt_mask_np))

    return (
        inputs[adapt_mask],
        targets[adapt_mask],
        inputs[eval_mask],
        targets[eval_mask],
    )

